{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "md-title",
   "metadata": {},
   "source": [
    "# Bayesian Synthetic Control: Chile's Democratic Transition (1990)\n",
    "\n",
    "Estimates the causal effect of Chile's 1990 return to democracy on GDP per capita\n",
    "using a Bayesian synthetic control (CausalPy + PyMC)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-setup",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import arviz as az\n",
    "import causalpy as cp\n",
    "import os\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "np.random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-data",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_panel(file_path='BBDD_input/base_4dic.csv'):\n",
    "    df = pd.read_csv(file_path)\n",
    "    df = df[df['year'] <= 2009]\n",
    "\n",
    "    # Countries excluded following Abadie et al. (2003) donor-pool criteria\n",
    "    exclude = ['China', 'Philippines', 'Canada',\n",
    "               'United States', 'South Africa', 'Australia']\n",
    "    df = df[~df['country_name'].isin(exclude)]\n",
    "\n",
    "    df['Time'] = pd.to_datetime(df['year'].astype(str) + '-01-01')\n",
    "\n",
    "    panel = df.pivot_table(\n",
    "        index='Time', columns='country_name',\n",
    "        values='gdp_percapita', aggfunc='mean'\n",
    "    )\n",
    "    # Drop countries with more than 30 % missing observations\n",
    "    panel = panel.dropna(axis=1, thresh=int(len(panel) * 0.7))\n",
    "    panel = panel.interpolate('linear').dropna()\n",
    "    return panel\n",
    "\n",
    "\n",
    "panel = load_panel()\n",
    "\n",
    "TREATMENT_TIME = pd.to_datetime('1990-01-01')\n",
    "TARGET   = 'Chile'\n",
    "CONTROLS = [c for c in panel.columns if c != TARGET]\n",
    "\n",
    "print(f\"Panel: {panel.shape[0]} periods x {panel.shape[1]} countries \"\n",
    "      f\"({panel.index.year.min()}-{panel.index.year.max()})\")\n",
    "print(f\"Treated unit : {TARGET}\")\n",
    "print(f\"Control pool : {len(CONTROLS)} countries\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-eda",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalized GDP per capita (intervention year = 100)\n",
    "norm = panel.div(panel.loc[TREATMENT_TIME]) * 100\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(11, 5))\n",
    "for c in CONTROLS:\n",
    "    ax.plot(norm.index, norm[c], color='grey', alpha=0.3, lw=1)\n",
    "ax.plot(norm.index, norm[TARGET], color='#D62728', lw=2.5, label='Chile')\n",
    "ax.axvline(TREATMENT_TIME, color='black', ls='--', lw=1,\n",
    "           label='Intervention (1990)')\n",
    "ax.set_ylabel('GDP per capita (1990 = 100)')\n",
    "ax.set_title('Chile vs. control pool – normalized GDP per capita')\n",
    "ax.legend()\n",
    "ax.grid(True, alpha=0.3)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-model",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_kwargs = dict(\n",
    "    tune=3000, draws=1000,\n",
    "    target_accept=0.95,\n",
    "    random_seed=42,\n",
    "    chains=4, cores=4,\n",
    ")\n",
    "\n",
    "result = cp.SyntheticControl(\n",
    "    data=panel,\n",
    "    treatment_time=TREATMENT_TIME,\n",
    "    control_units=CONTROLS,\n",
    "    treated_units=[TARGET],\n",
    "    model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-diagnostics",
   "metadata": {},
   "outputs": [],
   "source": [
    "summary = az.summary(result.idata, var_names=['~mu'])\n",
    "print(summary.round(3))\n",
    "\n",
    "rhat_ok = (summary['r_hat'] <= 1.01).all()\n",
    "ess_ok  = (summary['ess_bulk'] >= 200).all()\n",
    "print(f\"R-hat <= 1.01 : {'pass' if rhat_ok else 'FAIL'}  |  \"\n",
    "      f\"ESS >= 200 : {'pass' if ess_ok else 'FAIL'}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-traces",
   "metadata": {},
   "outputs": [],
   "source": [
    "az.plot_trace(result.idata, var_names=['beta', 'y_hat_sigma'],\n",
    "              compact=False, figsize=(14, max(8, len(CONTROLS) * 1.2)))\n",
    "plt.suptitle('MCMC traces – convergence diagnostics', y=1.01)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-results",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Observed vs. synthetic control, gap, and causal effect\n",
    "fig, axes = result.plot(treated_unit=TARGET, figsize=(12, 10))\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-weights",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Posterior distribution of control-country weights\n",
    "fig, ax = plt.subplots(figsize=(8, max(5, len(CONTROLS) * 0.4)))\n",
    "az.plot_forest(result.idata, var_names=['beta'], ax=ax)\n",
    "ax.set_title('Posterior weights – control countries')\n",
    "ax.set_xlabel('Weight')\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "beta_df = (\n",
    "    az.summary(result.idata, var_names=['beta'])\n",
    "    .sort_values('mean', ascending=False)\n",
    ")\n",
    "print('Top control countries by posterior mean weight:')\n",
    "print(beta_df[['mean', 'hdi_3%', 'hdi_97%']].head(8).round(3))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}